{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d89770a0",
   "metadata": {
    "id": "78dc1a9c-d009-407e-836d-84cf85936ade"
   },
   "source": [
    "# Safeguarding AI Virtual Assistant for Customer Service with NVIDIA NeMo Guardrails\n",
    "\n",
    "AI agents present a significant opportunity for businesses to scale and elevate customer service and support interactions. By automating routine inquiries and enhancing response times, these agents improve efficiency and customer satisfaction, helping organizations stay competitive. \n",
    "\n",
    "However, alongside these benefits, AI agents come with risks. Large language models (LLMs) are vulnerable to generating inappropriate or off-topic content and can be susceptible to jailbreak attacks. To fully realize the potential of generative AI in customer service, it is essential to implement robust AI safety and security measures.\n",
    "\n",
    "This tutorial equips AI builders with actionable steps to integrate essential safeguards into AI agents for customer service applications. We’ll explore how to integrate AI safeguard NIM microservices using NeMo Guardrails to build guardrail configurations that ensure your AI agent can identify and mitigate unsafe interactions in real time. Then, we’ll take it a step further by connecting these capabilities to the sophisticated agentic workflows outlined in the NVIDIA AI Blueprint for AI virtual assistants. By the end, you’ll have a clear understanding of how to create a scalable and secure AI assistant tailored to your brand’s unique needs. \n",
    "\n",
    "Figure 1 details the architecture workflow of integrating **[NeMo Guardrails](https://docs.nvidia.com/nemo/guardrails/index.html)** and safeguarding **[NIM microservices](https://developer.nvidia.com/nim)** in the **[NVIDIA AI Blueprint for virtual assistants](https://build.nvidia.com/nvidia/ai-virtual-assistant-for-customer-service)**.\n",
    "\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "### Docker compose\n",
    "\n",
    "#### System requirements\n",
    "\n",
    "Ubuntu 20.04 or 22.04 based machine, with sudo privileges\n",
    "\n",
    "Install software requirements\n",
    "- Install Docker Engine and Docker Compose. Refer to the instructions for Ubuntu.\n",
    "- Ensure the Docker Compose plugin version is 2.29.1 or higher.\n",
    "- Run docker compose version to confirm.\n",
    "- Refer to Install the Compose plugin in the Docker documentation for more information.\n",
    "- To configure Docker for GPU-accelerated containers, install the NVIDIA Container Toolkit.\n",
    "- Install git\n",
    "\n",
    "By default the provided configurations use GPU optimized databases such as Milvus.\n",
    "\n",
    "\n",
    "### Safety NIM Microservices\n",
    "\n",
    "#### Compute Requirements\n",
    "If you are going to deploy the **[Safety NIM Miccroservices](https://docs.nvidia.com/_preview?_cms.db.previewId=00000194-6b79-d6bc-a7b5-6bfb99b10000&_fields=true&_mainObjectId=&_date=#nemoguard)** using the downloadable containers from **[NVIDIA NGC](https://registry.ngc.nvidia.com/)**, then there might be a need for higher compute\n",
    "- minimum of 4xH100, or 4xA100\n",
    "\n",
    "If the Safety NIM Microservices are used using the build.nvidia.com endpoints, there is no need for additional compute\n",
    "- 1xH100, 1xA100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95539ada",
   "metadata": {},
   "source": [
    "## Getting API Keys - Very Important\n",
    "\n",
    "To run the pipeline you need to obtain an API key from NVIDIA. These will be needed in a later step to Set up the environment file.\n",
    "\n",
    "- Required API Keys: These APIs are required by the pipeline to execute LLM queries.\n",
    "\n",
    "- NVIDIA API Catalog\n",
    "  1. Navigate to **[NVIDIA API Catalog](https://build.nvidia.com/explore/discover)**.\n",
    "  2. Select any model, such as llama-3.3-70b-instruct.\n",
    "  3. On the right panel above the sample code snippet, click on \"Get API Key\". This will prompt you to log in if you have not already.\n",
    "\n",
    "NOTE: The API key starts with nvapi- and ends with a 32-character string. You can also generate an API key from the user settings page in NGC (https://ngc.nvidia.com/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1ccc43c",
   "metadata": {},
   "source": [
    "Export API Keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8303812e",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "044596c7-801c-4bfa-b9f6-e940cab81993",
    "outputId": "8f4c4700-9cc2-47f4-e122-3b1ae38b8ab3"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "NVIDIA_API_KEY = input(\"Please enter your NVIDIA API key (nvapi-): \")\n",
    "NGC_API_KEY = NVIDIA_API_KEY\n",
    "os.environ[\"NVIDIA_API_KEY\"] = NVIDIA_API_KEY\n",
    "os.environ[\"NGC_CLI_API_KEY\"] = NGC_API_KEY\n",
    "os.environ[\"NGC_API_KEY\"] = NGC_API_KEY"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f68f3b85",
   "metadata": {
    "id": "35890619-980d-4176-bbcb-c696a411d83f"
   },
   "source": [
    "# Step 1: Deploying the NIM Blueprint "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8088aee0",
   "metadata": {},
   "source": [
    "Open the jupyter notebook  **[./ai-virtual-assistant/deploy/ai_virtual_assistant_notebook.ipynb](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/blob/main/deploy/ai_virtual_assistant_notebook.ipynb)** and run through the cells (Shift + Enter) to start the **[NIM blueprint](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant)** and all the necessary docker containers. Following the same notebook, run the **[./ai-virtual-assistant/notebooks/ingest_data.ipynb](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/blob/main/notebooks/ingest_data.ipynb)** to ingest the structured and unstructured data types."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0146a8b",
   "metadata": {},
   "source": [
    "# Step 2: Download the NeMo Guardrails Toolkit \n",
    "\n",
    "Start by cloning the NeMo Guardrails repository:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1578b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/NVIDIA/NeMo-Guardrails.git nemoguardrails"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3646875",
   "metadata": {},
   "source": [
    "Make sure that the notebook is operating from the `ai-virtual-assistant` directory. If it's not, it changes to that directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aa96750",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "current_path = os.getcwd()\n",
    "last_part = os.path.basename(current_path)\n",
    "\n",
    "if os.path.basename(os.getcwd()) != \"ai-virtual-assistant\":\n",
    "    os.chdir(\"ai-virtual-assistant\")\n",
    "\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc186b81",
   "metadata": {
    "id": "02e1eeec"
   },
   "source": [
    "We login into the NGC catalogue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb79f78",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 106
    },
    "id": "ded8d982",
    "outputId": "14a0124b-1ae5-4136-b4f6-8d0ddb7b9843"
   },
   "outputs": [],
   "source": [
    "!docker login nvcr.io -u '$oauthtoken' -p $NGC_API_KEY"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d768ffcc",
   "metadata": {
    "id": "cdd03972"
   },
   "source": [
    "## Build the NeMo Guardrails with Docker\n",
    "\n",
    "First setup the `nemoguardrails.yaml` file for NeMo Guardrails and then launch the **[container](https://docs.nvidia.com/nemo/guardrails/user_guides/advanced/using-docker.html)** by using the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2345910e",
   "metadata": {
    "id": "a7387b26"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "docker compose -f deploy/compose/nemoguardrails.yaml up -d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "098800af",
   "metadata": {},
   "source": [
    "Before running the nemoguardrails server, we need to add the guardrails configuration. Let's deploy the safety NIMs and integrate it with NeMo Guardrails"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0a4ba76",
   "metadata": {},
   "source": [
    "## Step 3:  Deploying the Safety NIMs\n",
    "\n",
    "This tutorial equips AI builders with actionable steps to integrate essential safeguards into AI agents for customer service applications. It demonstrates how to leverage NVIDIA NeMo Guardrails, a scalable rail orchestration platform, including the following three new AI safeguard models offered as NVIDIA NIM microservices:\n",
    "\n",
    "**[Llama 3.1 NemoGuard 8B ContentSafety](https://build.nvidia.com/nvidia/llama-3_1-nemoguard-8b-content-safety)** for safeguarding input prompts and output responses in AI interactions, ensuring AI systems align with ethical standards. Llama 3.1 NemoGuard 8B ContentSafety is trained on the Aegis Content Safety Dataset including 35,000 human annotated AI safety data samples. It features explicit response labels curated through an automated process using an ensemble of LLM-as-a-judge across NVIDIA-developed and open community LLMs.\n",
    "\n",
    "**[Llama 3.1 NemoGuard 8B TopicControl](https://build.nvidia.com/nvidia/llama-3_1-nemoguard-8b-topic-control)** for keeping conversations focused on approved topics, avoiding derailment or inappropriate content. Llama 3.1 NemoGuard 8B TopicControl is fine-tuned on synthetic data to maintain context and enforce boundaries consistently throughout entire AI conversations. \n",
    "\n",
    "**[NemoGuard JailbreakDetect](https://build.nvidia.com/nvidia/nemoguard-jailbreak-detect)** for protection against jailbreak attempts, helping to maintain AI integrity in adversarial scenarios. NemoGuard JailbreakDetect is an LLM jailbreak classification model trained on a dataset of 17,000 known challenging and successful jailbreaks, built in part using NVIDIA Garak, an open-source toolkit for LLM and application vulnerability scanning developed by the NVIDIA Research team.\n",
    "\n",
    "Each of the Safety NIMs can be deployed either as a downloadable container or via the endpoint\n",
    "\n",
    "Let us see how to deploy NIMs as downloadable containers\n",
    "\n",
    "### Llama 3.1 NemoGuard 8B ContentSafety\n",
    "\n",
    "The Llama 3.1 NemoGuard 8B ContentSafety NIM follows a set of 42 Safety hazard categories with data distributions including annotations for jailbreak data, diverse cultural and geographical AI content safety like Hazards and Self Harm. Custom and novel safety risk categories and policy can also be provided in the instruction for the model to categorize using the novel taxonomy and policy. The model detects if the user input and/or the LLM response are safe or unsafe, and if unsafe, gives the violated category in the response. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b067b026",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir safety-nims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "837f23cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile safety-nims/content-safety.sh\n",
    "export NGC_API_KEY=<your NGC personal key>\n",
    "export NIM_IMAGE=<Path to latest NIM docker container>\n",
    "export MODEL_NAME=\"llama-3.1-nemoguard-8b-content-safety\"\n",
    "docker pull $NIM_IMAGE\n",
    "\n",
    "docker run -it --name=$MODEL_NAME \\\n",
    "    --gpus=\"device=0\" --runtime=nvidia \\\n",
    "    -e NGC_API_KEY=\"$NGC_API_KEY\" \\\n",
    "    -e NIM_SERVED_MODEL_NAME=$MODEL_NAME \\\n",
    "    -e NIM_CUSTOM_MODEL_NAME=$MODEL_NAME \\\n",
    "    -v $LOCAL_NIM_CACHE:\"/opt/nim/.cache/\" \\\n",
    "    -u $(id -u) \\\n",
    "    -p 8123:8000 \\\n",
    "    $NIM_IMAGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81edfb6",
   "metadata": {},
   "source": [
    "On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/content-safety.sh` to deploy the NIM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d86c72d6",
   "metadata": {},
   "source": [
    "### Llama 3.1 NemoGuard 8B TopicControl\n",
    "\n",
    "The Llama 3.1 NemoGuard 8B TopicControl NIM can be used for topical and dialogue moderation of user prompts in human-assistant interactions being designed for task-oriented dialogue agents and custom policy-based moderation. Given a system instruction (also called topical instruction, i.e. specifying which topics are allowed and disallowed) and a conversation history ending with the last user prompt, the model returns a binary response that flags if the user message respects the system instruction, (i.e. message is on-topic or a distractor/off-topic). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67e9ec9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile safety-nims/topic-control.sh\n",
    "export NGC_API_KEY=<your NGC personal key>\n",
    "export NIM_IMAGE=<Path to latest NIM docker container>\n",
    "export MODEL_NAME=\"llama-3.1-nemoguard-8b-topic-control\"\n",
    "docker pull $NIM_IMAGE\n",
    "\n",
    "docker run -it --name=$MODEL_NAME \\\n",
    "    --gpus=\"device=1\" --runtime=nvidia \\\n",
    "    -e NGC_API_KEY=\"$NGC_API_KEY\" \\\n",
    "    -e NIM_SERVED_MODEL_NAME=$MODEL_NAME \\\n",
    "    -e NIM_CUSTOM_MODEL_NAME=$MODEL_NAME \\\n",
    "    -v $LOCAL_NIM_CACHE:\"/opt/nim/.cache/\" \\\n",
    "    -u $(id -u) \\\n",
    "    -p 8124:8000 \\\n",
    "    $NIM_IMAGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dc5b5d7",
   "metadata": {},
   "source": [
    "On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/topic-control.sh` to deploy the NIM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c66f5c70",
   "metadata": {},
   "source": [
    "### NemoGuard JailbreakDetect\n",
    "The NemoGuard JailbreakDetect NIM was developed to detect attempts to jailbreak large language models. The Jailbreak detection model uses Snowflake-arctic-embed-m embeddings. It is trained on the combination of three open datasets, mixed together, de-duplicated, and reviewed for data quality. Jailbreak data was augmented with the use of garak."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14423b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile safety-nims/jailbreak-detect.sh\n",
    "export NGC_API_KEY=<your NGC personal key>\n",
    "export NIM_IMAGE=<Path to latest NIM docker container>\n",
    "export MODEL_NAME='ardennes-jailbreak-arctic'\n",
    "docker pull $NIM_IMAGE\n",
    "\n",
    "docker run -it --name=$MODEL_NAME \\\n",
    "    --gpus=\"device=1\" --runtime=nvidia \\\n",
    "    -e NGC_API_KEY=\"$NGC_API_KEY\" \\\n",
    "    -v $LOCAL_NIM_CACHE:\"/opt/nim/.cache/\" \\\n",
    "    -u $(id -u) \\\n",
    "    -p 8125:8000 \\\n",
    "    $NIM_IMAGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0e9572c",
   "metadata": {},
   "source": [
    "On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/jailbreak-detect.sh` to deploy the NIM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b899ca5",
   "metadata": {},
   "source": [
    "While building the guardrails configuration, integrate the three safeguard NIM microservices, start with creating the config directory: \n",
    "\n",
    "\n",
    "```\n",
    "├── config\n",
    "│   ├── config.yml\n",
    "│   ├── prompts.yml\n",
    "\n",
    "```\n",
    "\n",
    "Now, add each configuration option one by one, starting with the models in the config.yml file:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db7637ad",
   "metadata": {},
   "source": [
    "## Step 4: Building the NeMo Guardrails configuration with the Safety NIMs\n",
    "### 1. Creating the `config.yml` file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "991bef0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20df2406",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile config/config.yml\n",
    "models:\n",
    "  - type: main\n",
    "    engine: nim\n",
    "    model: meta/llama-3.1-70b-instruct\n",
    "\n",
    "  - type: \"llama-3.1-nemoguard-8b-content-safety\"\n",
    "    engine: nim\n",
    "    parameters:\n",
    "      base_url: \"http://localhost:8123/v1\"\n",
    "      model_name: \"llama-3.1-nemoguard-8b-content-safety\"\n",
    "  - type: topic_control\n",
    "    engine: nim\n",
    "    parameters:\n",
    "      base_url: \"http://localhost:8124/v1/\"\n",
    "      model_name: \"llama-3.1-nemoguard-8b-topic-control\"\n",
    "\n",
    "rails:\n",
    "  config:\n",
    "    jailbreak_detection:\n",
    "      server_endpoint: \"\"\n",
    "      embedding: \"Snowflake/snowflake-arctic-embed-m-long\"\n",
    "  input:\n",
    "    flows:\n",
    "      - content safety check input $model=llama-3.1-nemoguard-8b-content-safety\n",
    "      - topic safety check input $model=topic_control\n",
    "      - jailbreak detection model\n",
    "  output:\n",
    "    flows:\n",
    "      - content safety check output $model=llama-3.1-nemoguard-8b-content-safety"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda5a456",
   "metadata": {},
   "source": [
    "Both the Content safety NIM and the Topic control NIM come with a set of policy compliances which can be added to the guardrails configuration using the `prompts.yml` file\n",
    "\n",
    "### 2. Creating the `prompts.yml` file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040ed668",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile config/prompts.yml\n",
    "# These are the default prompts released by Meta, except for policy O7, which was added to address direct insults.\n",
    "prompts:\n",
    "\n",
    "  - task: content_safety_check_input $model=llama_3.1_aegis_guard_2.0\n",
    "    content: |\n",
    "      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.\n",
    "\n",
    "      <BEGIN UNSAFE CONTENT CATEGORIES>\n",
    "      S1: Violence.\n",
    "      S2: Sexual.\n",
    "      S3: Criminal Planning/Confessions.\n",
    "      S4: Guns and Illegal Weapons.\n",
    "      S5: Controlled/Regulated Substances.\n",
    "      S6: Suicide and Self Harm.\n",
    "      S7: Sexual (minor).\n",
    "      S8: Hate/Identity Hate.\n",
    "      S9: PII/Privacy.\n",
    "      S10: Harassment.\n",
    "      S11: Threat.\n",
    "      S12: Profanity.\n",
    "      S13: Needs Caution.\n",
    "      S14: Other.\n",
    "      S15: Manipulation.\n",
    "      S16: Fraud/Deception.\n",
    "      S17: Malware.\n",
    "      S18: High Risk Gov Decision Making.\n",
    "      S19: Political/Misinformation/Conspiracy.\n",
    "      S20: Copyright/Trademark/Plagiarism.\n",
    "      S21: Unauthorized Advice.\n",
    "      S22: Illegal Activity.\n",
    "      S23: Immoral/Unethical.\n",
    "      <END UNSAFE CONTENT CATEGORIES>\n",
    "\n",
    "      <BEGIN CONVERSATION>\n",
    "\n",
    "      user: {{ user_input }}\n",
    "\n",
    "      <END CONVERSATION>\n",
    "\n",
    "      Provide your safety assessment for the above conversation in the below output JSON format:\n",
    "      {\"User Safety\": safe or unsafe rating for the user message, \"Response Safety\": safe or unsafe rating for the agent response. Omit if no agent response present. \"Safety Categories\": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}\n",
    "\n",
    "      Do not include anything other than the output JSON in your response.\n",
    "      Output JSON:\n",
    "    output_parser: is_content_safe_aegis_input\n",
    "    max_tokens: 50\n",
    "\n",
    "  - task: content_safety_check_output $model=llama_3.1_aegis_guard_2.0\n",
    "    content: |\n",
    "      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.\n",
    "\n",
    "      <BEGIN UNSAFE CONTENT CATEGORIES>\n",
    "      S1: Violence.\n",
    "      S2: Sexual.\n",
    "      S3: Criminal Planning/Confessions.\n",
    "      S4: Guns and Illegal Weapons.\n",
    "      S5: Controlled/Regulated Substances.\n",
    "      S6: Suicide and Self Harm.\n",
    "      S7: Sexual (minor).\n",
    "      S8: Hate/Identity Hate.\n",
    "      S9: PII/Privacy.\n",
    "      S10: Harassment.\n",
    "      S11: Threat.\n",
    "      S12: Profanity.\n",
    "      S13: Needs Caution.\n",
    "      S14: Other.\n",
    "      S15: Manipulation.\n",
    "      S16: Fraud/Deception.\n",
    "      S17: Malware.\n",
    "      S18: High Risk Gov Decision Making.\n",
    "      S19: Political/Misinformation/Conspiracy.\n",
    "      S20: Copyright/Trademark/Plagiarism.\n",
    "      S21: Unauthorized Advice.\n",
    "      S22: Illegal Activity.\n",
    "      S23: Immoral/Unethical.\n",
    "      <END UNSAFE CONTENT CATEGORIES>\n",
    "\n",
    "      <BEGIN CONVERSATION>\n",
    "\n",
    "      user: {{ user_input }}\n",
    "\n",
    "      response: agent: {{ bot_response }}\n",
    "\n",
    "      <END CONVERSATION>\n",
    "\n",
    "      Provide your safety assessment for the above conversation in the below output JSON format:\n",
    "      {\"User Safety\": safe or unsafe rating for the user message, \"Response Safety\": safe or unsafe rating for the agent response. Omit if no agent response present. \"Safety Categories\": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}\n",
    "\n",
    "      Do not include anything other than the output JSON in your response.\n",
    "      Output JSON:\n",
    "    output_parser: is_content_safe_aegis_output\n",
    "    max_tokens: 50\n",
    "\n",
    "  - task: content_safety_check_input $model=llama_guard\n",
    "    content: |\n",
    "      <s>[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n",
    "        \n",
    "  - task: topic_safety_check_input $model=topic_control\n",
    "    content: |\n",
    "      You are to act as a customer service agent, providing users with factual information in accordance to the knowledge base. Your role is to ensure that you respond only to relevant queries and adhere to the following guidelines\n",
    "\n",
    "      Guidelines for the user messages:\n",
    "      - Do not answer questions related to personal opinions or advice on user's order, future recommendations\n",
    "      - Do not provide any information on non-company products or services.\n",
    "      - Do not answer enquiries unrelated to the company policies.\n",
    "      - Do not answer questions asking for personal details about the agent or its creators.\n",
    "      - Do not answer questions about sensitive topics related to politics, religion, or other sensitive subjects.\n",
    "      - If a user asks topics irrelevant to the company's customer service relations, politely redirect the conversation or end the interaction.\n",
    "      - Your responses should be professional, accurate, and compliant with customer relations guidelines, focusing solely on providing transparent, up-to-date information about the company that is already publicly available."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdaf1f39",
   "metadata": {},
   "source": [
    "## Step5: Wrapping the guardrails configuration around the agentic system"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be30a152",
   "metadata": {},
   "source": [
    "With the configuration complete, you could use it as is to apply guardrails to a general-purpose conversational AI by interfacing with the NeMo Guardrails server through its API. The assistant or agent from the NIM Blueprint performs multiple tasks, a few including RAG, checking if the user is compliant with the return policy, and thereby updating the return option, getting the user's purchase history. \n",
    "\n",
    "Start with adding chains to the following agent components\n",
    "- `src/analytics/main.py`\n",
    "- `src/agent/utils.py`\n",
    "- `src/agent/main.py`\n",
    "\n",
    "#### 1. Analytics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cbef9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile src/analytics/main.py\n",
    "# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
    "# SPDX-License-Identifier: Apache-2.0\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "import os\n",
    "import logging\n",
    "from datetime import datetime\n",
    "from enum import Enum\n",
    "from typing import Annotated, Generator, Literal, Sequence, TypedDict\n",
    "\n",
    "from langchain_core.messages import BaseMessage\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts.chat import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "from pydantic import BaseModel, Field\n",
    "from src.analytics.datastore.session_manager import SessionManager\n",
    "from src.common.utils import get_config, get_llm, get_prompts\n",
    "from nemoguardrails import RailsConfig\n",
    "from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "prompts = get_prompts()\n",
    "\n",
    "# TODO get the default_kwargs from the Agent Server API\n",
    "default_llm_kwargs = {\"temperature\": 0, \"top_p\": 0.7, \"max_tokens\": 1024}\n",
    "\n",
    "# Initialize persist_data to determine whether data should be stored in the database.\n",
    "persist_data = os.environ.get(\"PERSIST_DATA\", \"true\").lower() == \"true\"\n",
    "\n",
    "# Initialize session manager during startup\n",
    "session_manager = None\n",
    "try:\n",
    "    session_manager = SessionManager()\n",
    "except Exception as e:\n",
    "    logger.info(f\"Failed to connect to DB during init, due to exception {e}\")\n",
    "    \n",
    "# Initialize  guardrails configuration\n",
    "rail_config = RailsConfig.from_path(\"./config\")\n",
    "guardrails = RunnableRails(rail_config, input_key=\"query\", output_key='content')\n",
    "    \n",
    "    \n",
    "def get_database():\n",
    "    \"\"\"\n",
    "    Connect to the database.\n",
    "    \"\"\"\n",
    "    global session_manager\n",
    "    try:\n",
    "        if not session_manager:\n",
    "            session_manager = SessionManager()\n",
    "\n",
    "        return session_manager\n",
    "    except Exception as e:\n",
    "        logger.info(f\"Error connecting to database: {e}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def generate_summary(conversation_history):\n",
    "    \"\"\"\n",
    "    Generate a summary of the conversation.\n",
    "\n",
    "    Parameters:\n",
    "        conversation_history (List): The conversation text.\n",
    "\n",
    "    Returns:\n",
    "        str: A summary of the conversation.\n",
    "    \"\"\"\n",
    "    logger.info(f\"conversation history: {conversation_history}\")\n",
    "    llm = get_llm(**default_llm_kwargs)\n",
    "    prompt = prompts.get(\"summary_prompt\", \"\")\n",
    "    for turn in conversation_history:\n",
    "        prompt += f\"{turn['role']}: {turn['content']}\\n\"\n",
    "\n",
    "    prompt += \"\\n\\nSummary: \"\n",
    "    \n",
    "    # Apply guardrails to the chain\n",
    "    chain_with_guardrails = guardrails | llm\n",
    "    response = chain_with_guardrails.invoke({\"query\": prompt})\n",
    "\n",
    "    return response.content\n",
    "\n",
    "\n",
    "def generate_session_summary(session_id):\n",
    "    # TODO: Check for corner cases like when session_id does not exist\n",
    "    session_manager = get_database()\n",
    "\n",
    "    # Check if summary already exists in database\n",
    "    session_info = session_manager.get_session_summary_and_sentiment(session_id)\n",
    "    if session_info and session_info.get(\"summary\", None):\n",
    "        return session_info\n",
    "\n",
    "    # Generate summary and session info\n",
    "    conversation_history = session_manager.get_conversation(session_id)\n",
    "    summary = generate_summary(conversation_history)\n",
    "    sentiment = generate_sentiment(conversation_history)\n",
    "\n",
    "    if persist_data:\n",
    "        # Save the summary and sentiment in database\n",
    "        session_manager.save_summary_and_sentiment(\n",
    "            session_id,\n",
    "            {\n",
    "                \"summary\": summary,\n",
    "                \"sentiment\": sentiment,\n",
    "                \"start_time\": conversation_history[0].get(\"timestamp\", 0),\n",
    "                \"end_time\": conversation_history[-1].get(\"timestamp\", 0),\n",
    "            }\n",
    "        )\n",
    "    return {\n",
    "        \"summary\": summary,\n",
    "        \"sentiment\": sentiment,\n",
    "        \"start_time\": datetime.fromtimestamp(\n",
    "            float(conversation_history[0].get(\"timestamp\", 0))\n",
    "        ),\n",
    "        \"end_time\": datetime.fromtimestamp(\n",
    "            float(conversation_history[-1].get(\"timestamp\", 0))\n",
    "        ),\n",
    "    }\n",
    "\n",
    "\n",
    "def fetch_user_conversation(user_id, start_time=None, end_time=None):\n",
    "    \"\"\"\n",
    "    Fetch a user's conversation from the database.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # TODO: Use start time and end time to filter the data\n",
    "        session_manager = get_database()\n",
    "        conversations = session_manager.list_sessions_for_user(user_id)\n",
    "        logger.info(f\"Conversation: {conversations}\")\n",
    "        return conversations\n",
    "    except Exception as e:\n",
    "        logger.error(f\"Error fetching conversation: {e}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def generate_sentiment(conversation_history):\n",
    "    # Define an Enum for the sentiment values\n",
    "    class SentimentEnum(str, Enum):\n",
    "        POSITIVE = \"positive\"\n",
    "        NEUTRAL = \"neutral\"\n",
    "        NEGATIVE = \"negative\"\n",
    "\n",
    "    # Define the Pydantic model using the Enum\n",
    "    class Sentiment(BaseModel):\n",
    "        \"\"\"Sentiment for conversation.\"\"\"\n",
    "\n",
    "        sentiment: SentimentEnum = Field(\n",
    "            description=\"Relevant value 'positive', 'neutral' or 'negative'\"\n",
    "        )\n",
    "\n",
    "    logger.info(\"Finding sentiment for conversation\")\n",
    "    llm = get_llm(**default_llm_kwargs)\n",
    "    prompt = prompts.get(\"sentiment_prompt\", \"\")\n",
    "    for turn in conversation_history:\n",
    "        prompt += f\"{turn['role']}: {turn['content']}\\n\"\n",
    "\n",
    "    llm_with_tool = llm.with_structured_output(Sentiment)\n",
    "\n",
    "    # Apply guardrails to the chain\n",
    "    chain_with_guardrails = guardrails | llm_with_tool\n",
    "    response = chain_with_guardrails.invoke({\"query\": prompt})\n",
    "    \n",
    "    sentiment = response.content.sentiment.value\n",
    "    logger.info(f\"Conversation classified as {sentiment}\")\n",
    "    return sentiment\n",
    "\n",
    "\n",
    "def generate_sentiment_for_query(session_id):\n",
    "    \"\"\"Generate sentiment for user query and assistant response\n",
    "    \"\"\"\n",
    "\n",
    "    logger.info(\"Fetching sentiment for queries\")\n",
    "    # Check if the sentiment is already identified in database, if yes return that\n",
    "    session_manager = get_database()\n",
    "\n",
    "    session_info = session_manager.get_query_sentiment(session_id)\n",
    "\n",
    "    if session_info and session_info.get(\"messages\", None):\n",
    "        return {\n",
    "        \"messages\": session_info.get(\"messages\"),\n",
    "            \"session_info\": {\n",
    "                \"session_id\": session_id,\n",
    "                \"start_time\": session_info.get(\"start_time\"),\n",
    "                \"end_time\": session_info.get(\"start_time\"),\n",
    "            },\n",
    "        }\n",
    "\n",
    "    class SentimentEnum(str, Enum):\n",
    "        POSITIVE = \"positive\"\n",
    "        NEUTRAL = \"neutral\"\n",
    "        NEGATIVE = \"negative\"\n",
    "\n",
    "    # Define the Pydantic model using the Enum\n",
    "    class Sentiment(BaseModel):\n",
    "        \"\"\"Sentiment for conversation.\"\"\"\n",
    "\n",
    "        sentiment: SentimentEnum = Field(\n",
    "            description=\"Relevant value 'positive', 'neutral' or 'negative'\"\n",
    "        )\n",
    "\n",
    "\n",
    "    # Generate summary and session info\n",
    "    conversation_history = session_manager.get_conversation(session_id)\n",
    "    logger.info(f\"Conversation history: {conversation_history}\")\n",
    "\n",
    "    logger.info(\"Finding sentiment for conversation\")\n",
    "    llm = get_llm(**default_llm_kwargs)\n",
    "\n",
    "    llm_with_tool = llm.with_structured_output(Sentiment)\n",
    "    \n",
    "    # Apply guardrails to the chain\n",
    "    chain_with_guardrails = guardrails | llm_with_tool\n",
    "\n",
    "    messages = []\n",
    "    # TODO: parallize this operation for faster response\n",
    "    # Find sentiment for individual query and assistant response\n",
    "    for turn in conversation_history:\n",
    "        prompt = prompts.get(\"query_sentiment_prompt\", \"\")\n",
    "        prompt += f\"{turn['role']}: {turn['content']}\\n\"\n",
    "\n",
    "        response = chain_with_guardrails.invoke({\"query\": prompt})\n",
    "        sentiment = response.content.sentiment.value\n",
    "        messages.append({\n",
    "            \"role\": turn[\"role\"],\n",
    "            \"content\": turn[\"content\"],\n",
    "            \"sentiment\": sentiment,\n",
    "        })\n",
    "\n",
    "    session_info = {\n",
    "        \"messages\": messages,\n",
    "        \"start_time\": conversation_history[0].get(\"timestamp\", 0),\n",
    "        \"end_time\": conversation_history[-1].get(\"timestamp\", 0),\n",
    "    }\n",
    "    if persist_data:\n",
    "        # Save information before sending it to user\n",
    "        session_manager.save_query_sentiment(session_id, session_info)\n",
    "    return {\n",
    "        \"messages\": messages,\n",
    "            \"session_info\": {\n",
    "                \"session_id\": session_id,\n",
    "                \"start_time\": datetime.fromtimestamp(\n",
    "                    float(conversation_history[0].get(\"timestamp\", 0))\n",
    "                ),\n",
    "                \"end_time\": datetime.fromtimestamp(\n",
    "                    float(conversation_history[-1].get(\"timestamp\", 0))\n",
    "                ),\n",
    "            },\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53cec8f7",
   "metadata": {},
   "source": [
    "#### 2. Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b45f256",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile src/agent/utils.py\n",
    "# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
    "# SPDX-License-Identifier: Apache-2.0\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "import re\n",
    "import os\n",
    "import logging\n",
    "from typing import Dict\n",
    "from pydantic import BaseModel, Field\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "import requests\n",
    "\n",
    "from psycopg_pool import AsyncConnectionPool\n",
    "from psycopg.rows import dict_row\n",
    "import psycopg2\n",
    "\n",
    "from src.common.utils import get_llm, get_prompts, get_config\n",
    "from langchain_core.prompts.chat import ChatPromptTemplate\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langchain_core.messages import ToolMessage\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from nemoguardrails import RailsConfig\n",
    "from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails\n",
    "\n",
    "prompts = get_prompts()\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# TODO get the default_kwargs from the Agent Server API\n",
    "default_llm_kwargs = {\"temperature\": 0, \"top_p\": 0.7, \"max_tokens\": 1024}\n",
    "\n",
    "canonical_rag_url = os.getenv('CANONICAL_RAG_URL', 'http://unstructured-retriever:8081')\n",
    "canonical_rag_search = f\"{canonical_rag_url}/search\"\n",
    "\n",
    "# Initialize  guardrails configuration\n",
    "rail_config = RailsConfig.from_path(\"./config\")\n",
    "guardrails = RunnableRails(rail_config, input_key=\"query\", output_key='content')\n",
    "\n",
    "def get_product_name(messages, product_list) -> Dict:\n",
    "    \"\"\"Given the user message and list of product find list of items which user might be talking about\"\"\"\n",
    "\n",
    "    # First check product name in query\n",
    "    # If it's not in query, check in conversation\n",
    "    # Once the product name is known we will search for product name from database\n",
    "    # We will return product name from list and actual name detected.\n",
    "\n",
    "    llm = get_llm(**default_llm_kwargs)\n",
    "\n",
    "    class Product(BaseModel):\n",
    "        name: str = Field(..., description=\"Name of the product talked about.\")\n",
    "\n",
    "    prompt_text = prompts.get(\"get_product_name\")[\"base_prompt\"]\n",
    "    prompt = ChatPromptTemplate.from_messages(\n",
    "        [\n",
    "            (\"system\", prompt_text),\n",
    "        ]\n",
    "    )\n",
    "    llm = llm.with_structured_output(Product)\n",
    "\n",
    "    chain = prompt | llm\n",
    "    # Adding guardrails to the chain\n",
    "    chain_with_guardrails = guardrails | chain\n",
    "    # query to be used for document retrieval\n",
    "    # Get the last human message instead of messages[-2]\n",
    "    last_human_message = next((m.content for m in reversed(messages) if isinstance(m, HumanMessage)), None)\n",
    "    response = chain_with_guardrails.invoke({\"query\": last_human_message})\n",
    "\n",
    "    product_name = response.content.name\n",
    "\n",
    "    # Check if product name is in query\n",
    "    if product_name == 'null':\n",
    "\n",
    "        # Check for produt name in user conversation\n",
    "        fallback_prompt_text = prompts.get(\"get_product_name\")[\"fallback_prompt\"]\n",
    "        prompt = ChatPromptTemplate.from_messages(\n",
    "            [\n",
    "                (\"system\", fallback_prompt_text),\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        llm = get_llm(**default_llm_kwargs)\n",
    "        llm = llm.with_structured_output(Product)\n",
    "\n",
    "        chain = prompt | llm\n",
    "        # Adding guardrails to the chain\n",
    "        chain_with_guardrails = guardrails | chain\n",
    "        # query to be used for document retrieval\n",
    "        response = chain.invoke({\"query\": messages})\n",
    "\n",
    "        product_name = response.content.name\n",
    "    # Check if it's partial name exists or not\n",
    "    if product_name == 'null':\n",
    "        return {}\n",
    "\n",
    "    def filter_products_by_name(name, products):\n",
    "        # TODO: Replace this by llm call to check if that can take care of cases like\n",
    "        # spelling mistakes or words which are seperated\n",
    "        # TODO: Directly make sql query with wildcard\n",
    "        name_lower = name.lower()\n",
    "\n",
    "        # Check for exact match first\n",
    "        exact_match = [product for product in products if product.lower() == name_lower]\n",
    "        if exact_match:\n",
    "            return exact_match\n",
    "\n",
    "        # If no exact match, fall back to partial matches\n",
    "        name_parts = [part for part in re.split(r'\\s+', name_lower) if part.lower() != 'nvidia']\n",
    "        # Match only if all parts of the search term are found in the product name\n",
    "        matching_products = [\n",
    "            product for product in products\n",
    "            if all(part in product.lower() for part in name_parts if part)\n",
    "        ]\n",
    "\n",
    "        return matching_products\n",
    "\n",
    "    matching_products = filter_products_by_name(product_name, product_list)\n",
    "\n",
    "    return {\n",
    "        \"product_in_query\": product_name,\n",
    "        \"products_from_purchase\": list(set([product for product in matching_products]))\n",
    "    }\n",
    "\n",
    "\n",
    "def handle_tool_error(state) -> dict:\n",
    "    error = state.get(\"error\")\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            ToolMessage(\n",
    "                content=f\"Error: {repr(error)}\\n please fix your mistakes.\",\n",
    "                tool_call_id=tc[\"id\"],\n",
    "            )\n",
    "            for tc in tool_calls\n",
    "        ]\n",
    "    }\n",
    "\n",
    "\n",
    "def create_tool_node_with_fallback(tools: list) -> dict:\n",
    "    return ToolNode(tools).with_fallbacks(\n",
    "        [RunnableLambda(handle_tool_error)], exception_key=\"error\"\n",
    "    )\n",
    "\n",
    "\n",
    "async def get_checkpointer() -> tuple:\n",
    "    settings = get_config()\n",
    "\n",
    "    if settings.checkpointer.name == \"postgres\":\n",
    "        print(f\"Using {settings.checkpointer.name} hosted on {settings.checkpointer.url} for checkpointer\")\n",
    "        db_user = os.environ.get(\"POSTGRES_USER\")\n",
    "        db_password = os.environ.get(\"POSTGRES_PASSWORD\")\n",
    "        db_name = os.environ.get(\"POSTGRES_DB\")\n",
    "        db_uri = f\"postgresql://{db_user}:{db_password}@{settings.checkpointer.url}/{db_name}?sslmode=disable\"\n",
    "        connection_kwargs = {\n",
    "            \"autocommit\": True,\n",
    "            \"prepare_threshold\": 0,\n",
    "            \"row_factory\": dict_row,\n",
    "        }\n",
    "\n",
    "        # Initialize PostgreSQL checkpointer\n",
    "        pool = AsyncConnectionPool(\n",
    "            conninfo=db_uri,\n",
    "            min_size=2,\n",
    "            kwargs=connection_kwargs,\n",
    "        )\n",
    "        checkpointer = AsyncPostgresSaver(pool)\n",
    "        await checkpointer.setup()\n",
    "        return checkpointer, pool\n",
    "    elif settings.checkpointer.name == \"inmemory\":\n",
    "        print(f\"Using MemorySaver as checkpointer\")\n",
    "        return MemorySaver(), None\n",
    "    else:\n",
    "        raise ValueError(f\"Only inmemory and postgres is supported chckpointer type\")\n",
    "\n",
    "\n",
    "def remove_state_from_checkpointer(session_id):\n",
    "\n",
    "    settings = get_config()\n",
    "    if settings.checkpointer.name == \"postgres\":\n",
    "        # Handle cleanup for PostgreSQL checkpointer\n",
    "        # Currently, there is no langgraph checkpointer API to remove data directly.\n",
    "        # The following tables are involved in storing checkpoint data:\n",
    "        # - checkpoint_blobs\n",
    "        # - checkpoint_writes\n",
    "        # - checkpoints\n",
    "        # Note: checkpoint_migrations table can be skipped for deletion.\n",
    "        try:\n",
    "            app_database_url = settings.checkpointer.url\n",
    "\n",
    "            # Parse the URL\n",
    "            parsed_url = urlparse(f\"//{app_database_url}\", scheme='postgres')\n",
    "\n",
    "            # Extract host and port\n",
    "            host = parsed_url.hostname\n",
    "            port = parsed_url.port\n",
    "\n",
    "            # Connect to your PostgreSQL database\n",
    "            connection = psycopg2.connect(\n",
    "                dbname=os.getenv('POSTGRES_DB', None),\n",
    "                user=os.getenv('POSTGRES_USER', None),\n",
    "                password=os.getenv('POSTGRES_PASSWORD', None),\n",
    "                host=host,\n",
    "                port=port\n",
    "            )\n",
    "            cursor = connection.cursor()\n",
    "\n",
    "            # Execute delete commands\n",
    "            cursor.execute(\"DELETE FROM checkpoint_blobs WHERE thread_id = %s\", (session_id,))\n",
    "            cursor.execute(\"DELETE FROM checkpoint_writes WHERE thread_id = %s\", (session_id,))\n",
    "            cursor.execute(\"DELETE FROM checkpoints WHERE thread_id = %s\", (session_id,))\n",
    "\n",
    "            # Commit the changes\n",
    "            connection.commit()\n",
    "            logger.info(f\"Deleted rows with thread_id: {session_id}\")\n",
    "\n",
    "        except Exception as e:\n",
    "            logger.info(f\"Error occurred while deleting data from checkpointer: {e}\")\n",
    "            # Optionally rollback if needed\n",
    "            if connection:\n",
    "                connection.rollback()\n",
    "        finally:\n",
    "            # Close the cursor and connection\n",
    "            if cursor:\n",
    "                cursor.close()\n",
    "            if connection:\n",
    "                connection.close()\n",
    "    else:\n",
    "        # For other supported checkpointer(i.e. inmemory) we don't need cleanup\n",
    "        pass\n",
    "\n",
    "def canonical_rag(query: str, conv_history: list)  -> str:\n",
    "    \"\"\"Use this for answering generic queries about products, specifications, warranties, usage, and issues.\"\"\"\n",
    "\n",
    "    entry_doc_search = {\"query\": query, \"top_k\": 4, \"conv_history\": conv_history}\n",
    "    response = requests.post(canonical_rag_search, json=entry_doc_search).json()\n",
    "\n",
    "    # Extract and aggregate the content\n",
    "    aggregated_content = \"\\n\".join(chunk[\"content\"] for chunk in response.get(\"chunks\", []))\n",
    "\n",
    "    return aggregated_content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d4499af",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile src/agent/main.py\n",
    "# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
    "# SPDX-License-Identifier: Apache-2.0\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "import logging\n",
    "import os\n",
    "from typing import Annotated, TypedDict, Dict\n",
    "from langgraph.graph.message import AnyMessage, add_messages\n",
    "from typing import Callable\n",
    "from langchain_core.messages import ToolMessage, AIMessage, HumanMessage, SystemMessage\n",
    "from typing import Annotated, Optional, Literal, TypedDict\n",
    "from langchain_core.prompts.chat import ChatPromptTemplate\n",
    "from langchain_core.prompts import MessagesPlaceholder\n",
    "from langgraph.graph import END, StateGraph, START\n",
    "from langgraph.prebuilt import tools_condition\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "from src.agent.tools import (\n",
    "        structured_rag, get_purchase_history, HandleOtherTalk, ProductValidation,\n",
    "        return_window_validation, update_return, get_recent_return_details,\n",
    "        ToProductQAAssistant,\n",
    "        ToOrderStatusAssistant,\n",
    "        ToReturnProcessing)\n",
    "from src.agent.utils import get_product_name, create_tool_node_with_fallback, get_checkpointer, canonical_rag\n",
    "from src.common.utils import get_llm, get_prompts\n",
    "from nemoguardrails import RailsConfig\n",
    "from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "prompts = get_prompts()\n",
    "# TODO get the default_kwargs from the Agent Server API\n",
    "default_llm_kwargs = {\"temperature\": 0.2, \"top_p\": 0.7, \"max_tokens\": 1024}\n",
    "\n",
    "# Initialize  guardrails configuration\n",
    "rail_config = RailsConfig.from_path(\"./config\")\n",
    "guardrails = RunnableRails(rail_config, input_key=\"query\", output_key='content')\n",
    "\n",
    "# STATE OF THE AGENT\n",
    "class State(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], add_messages]\n",
    "    user_id: str\n",
    "    user_purchase_history: Dict\n",
    "    current_product: str\n",
    "    needs_clarification: bool\n",
    "    clarification_type: str\n",
    "    reason: str\n",
    "\n",
    "# NODES FOR THE AGENT\n",
    "def validate_product_info(state: State, config: RunnableConfig):\n",
    "    # This node will take user history and find product name based on query\n",
    "    # If there are multiple name of no name specified in the graph then it will\n",
    "\n",
    "    # This dict is to populate the user_purchase_history and product details if required\n",
    "    response_dict = {\"needs_clarification\": False}\n",
    "    if state[\"user_id\"]:\n",
    "        # Update user purchase history based\n",
    "        response_dict.update({\"user_purchase_history\": get_purchase_history(state[\"user_id\"])})\n",
    "\n",
    "        # Extracting product name which user is expecting\n",
    "        product_list = list(set([resp.get(\"product_name\") for resp in response_dict.get(\"user_purchase_history\", [])]))\n",
    "\n",
    "        # Extract product name from query and filter from database\n",
    "        product_info = get_product_name(state[\"messages\"], product_list)\n",
    "\n",
    "        product_names = product_info.get(\"products_from_purchase\", [])\n",
    "        product_in_query = product_info.get(\"product_in_query\", \"\")\n",
    "        if len(product_names) == 0:\n",
    "            reason = \"\"\n",
    "            if product_in_query:\n",
    "                reason = f\"{product_in_query}\"\n",
    "            response_dict.update({\"needs_clarification\": True, \"clarification_type\": \"no_product\", \"reason\": reason})\n",
    "            return response_dict\n",
    "        elif len(product_names) > 1:\n",
    "            reason = \", \".join(product_names)\n",
    "            response_dict.update({\"needs_clarification\": True, \"clarification_type\": \"multiple_products\", \"reason\": reason})\n",
    "            return response_dict\n",
    "        else:\n",
    "            response_dict.update({\"current_product\": product_names[0]})\n",
    "\n",
    "    return response_dict\n",
    "\n",
    "async def handle_other_talk(state: State, config: RunnableConfig):\n",
    "    \"\"\"Handles greetings and queries outside order status, returns, or products, providing polite redirection and explaining chatbot limitations.\"\"\"\n",
    "\n",
    "    prompt = prompts.get(\"other_talk_template\", \"\")\n",
    "\n",
    "    prompt = ChatPromptTemplate.from_messages(\n",
    "        [\n",
    "        (\"system\", prompt),\n",
    "        (\"placeholder\", \"{messages}\"),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # LLM\n",
    "    llm_settings = config.get('configurable', {}).get(\"llm_settings\", default_llm_kwargs)\n",
    "    llm = get_llm(**llm_settings)\n",
    "    llm = llm.with_config(tags=[\"should_stream\"])\n",
    "\n",
    "    # Chain\n",
    "    small_talk_chain = prompt | llm\n",
    "    \n",
    "    # Adding guardrails\n",
    "    small_talk_chain_guardrails = guardrails | small_talk_chain\n",
    "    response = await small_talk_chain_guardrails.ainvoke(state, config)\n",
    "\n",
    "    return {\"messages\": [response.content]}\n",
    "\n",
    "\n",
    "def create_entry_node(assistant_name: str) -> Callable:\n",
    "    def entry_node(state: State) -> dict:\n",
    "        tool_call_id = state[\"messages\"][-1].tool_calls[0][\"id\"]\n",
    "        return {\n",
    "            \"messages\": [\n",
    "                ToolMessage(\n",
    "                    content=f\"The assistant is now the {assistant_name}. Reflect on the above conversation between the host assistant and the user.\"\n",
    "                    f\" The user's intent is unsatisfied. Use the provided tools to assist the user. Remember, you are {assistant_name},\"\n",
    "                    \" and the booking, update, other other action is not complete until after you have successfully invoked the appropriate tool.\"\n",
    "                    \" If the user changes their mind or needs help for other tasks, let the primary host assistant take control.\"\n",
    "                    \" Do not mention who you are - just act as the proxy for the assistant.\",\n",
    "                    tool_call_id=tool_call_id,\n",
    "                )\n",
    "            ]\n",
    "        }\n",
    "\n",
    "    return entry_node\n",
    "\n",
    "async def ask_clarification(state: State, config: RunnableConfig):\n",
    "\n",
    "    # Extract the base prompt\n",
    "    base_prompt = prompts.get(\"ask_clarification\")[\"base_prompt\"]\n",
    "    previous_conversation = [m for m in state['messages'] if not isinstance(m, ToolMessage)]\n",
    "    base_prompt = base_prompt.format(previous_conversation=previous_conversation)\n",
    "\n",
    "    purchase_history = state.get(\"user_purchase_history\", [])\n",
    "    if state[\"clarification_type\"] == \"no_product\" and state['reason'].strip():\n",
    "        followup_prompt = prompts.get(\"ask_clarification\")[\"followup\"][\"no_product\"].format(\n",
    "            reason=state['reason'],\n",
    "            purchase_history=purchase_history\n",
    "        )\n",
    "    elif not state['reason'].strip():\n",
    "        followup_prompt = prompts.get(\"ask_clarification\")[\"followup\"][\"default\"].format(reason=purchase_history)\n",
    "    else:\n",
    "        followup_prompt = prompts.get(\"ask_clarification\")[\"followup\"][\"default\"].format(reason=state['reason'])\n",
    "\n",
    "    # Combine base prompt and followup prompt\n",
    "    prompt = f\"{base_prompt} {followup_prompt}\"\n",
    "\n",
    "    # LLM\n",
    "    llm_settings = config.get('configurable', {}).get(\"llm_settings\", default_llm_kwargs)\n",
    "    llm = get_llm(**llm_settings)\n",
    "    llm = llm.with_config(tags=[\"should_stream\"])\n",
    "    \n",
    "    # Adding the guardrails\n",
    "    chain_with guardrails = guardrails | llm\n",
    "\n",
    "    response = await chain_with_guardrails.ainvoke(prompt, config)\n",
    "\n",
    "    return {\"messages\": [response.content]}\n",
    "\n",
    "async def handle_product_qa(state: State, config: RunnableConfig):\n",
    "\n",
    "    # Extract the previous_conversation\n",
    "    previous_conversation = [m for m in state['messages'] if not isinstance(m, ToolMessage) and m.content]\n",
    "    message_type_map = {\n",
    "        HumanMessage: \"user\",\n",
    "        AIMessage: \"assistant\",\n",
    "        SystemMessage: \"system\"\n",
    "    }\n",
    "\n",
    "    # Serialized conversation\n",
    "    get_role = lambda x: message_type_map.get(type(x), None)\n",
    "    previous_conversation_serialized = [{\"role\": get_role(m), \"content\": m.content} for m in previous_conversation if m.content]\n",
    "    last_message = previous_conversation_serialized[-1]['content']\n",
    "\n",
    "    retireved_content = canonical_rag(query=last_message, conv_history=previous_conversation_serialized)\n",
    "\n",
    "    # Use the RAG Template to generate the response\n",
    "    base_rag_prompt = prompts.get(\"rag_template\")\n",
    "    rag_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", base_rag_prompt),\n",
    "        MessagesPlaceholder(\"chat_history\") + \"\\n\\nCONTEXT:  {context}\"\n",
    "    ]\n",
    "    )\n",
    "    rag_prompt = rag_prompt.format(chat_history=previous_conversation, context=retireved_content)\n",
    "\n",
    "    # LLM\n",
    "    llm_settings = config.get('configurable', {}).get(\"llm_settings\", default_llm_kwargs)\n",
    "    llm = get_llm(**llm_settings)\n",
    "    llm = llm.with_config(tags=[\"should_stream\"])\n",
    "    \n",
    "    # Adding guardrails\n",
    "    chain_with_guardrails = guardrails | llm\n",
    "\n",
    "    response = await chain_with_guardrails.ainvoke(rag_prompt, config)\n",
    "\n",
    "    return {\"messages\": [response.content]}\n",
    "\n",
    "class Assistant:\n",
    "    def __init__(self, prompt: str, tools: list):\n",
    "        self.prompt = prompt\n",
    "        self.tools = tools\n",
    "\n",
    "    async def __call__(self, state: State, config: RunnableConfig):\n",
    "        while True:\n",
    "\n",
    "            llm_settings = config.get('configurable', {}).get(\"llm_settings\", default_llm_kwargs)\n",
    "            llm = get_llm(**llm_settings)\n",
    "            runnable = self.prompt | llm.bind_tools(self.tools)\n",
    "            runnable_with_guardrails = guardrails | runnable\n",
    "            state = await runnable_with_guardrails.invoke(state)\n",
    "            last_message = state[\"messages\"][-1]\n",
    "            messages = []\n",
    "            if isinstance(last_message, ToolMessage) and last_message.name in [\"structured_rag\", \"return_window_validation\", \"update_return\", \"get_purchase_history\", \"get_recent_return_details\"]:\n",
    "                gen = runnable.with_config(\n",
    "                tags=[\"should_stream\"],\n",
    "                callbacks=config.get(\n",
    "                    \"callbacks\", []\n",
    "                ),  # <-- Propagate callbacks (Python <= 3.10)\n",
    "                )\n",
    "                async for message in gen.astream(state):\n",
    "                    messages.append(message.content)\n",
    "                result = AIMessage(content=\"\".join(messages))\n",
    "            else:\n",
    "                result = runnable_with_guardrails.invoke(state)\n",
    "\n",
    "            if not result.tool_calls and (\n",
    "                not result.content\n",
    "                or isinstance(result.content, list)\n",
    "                and not result.content[0].get(\"text\")\n",
    "            ):\n",
    "                messages = state[\"messages\"] + [(\"user\", \"Respond with a real output.\")]\n",
    "                state = {**state, \"messages\": messages}\n",
    "                messages = state[\"messages\"] + [(\"user\", \"Respond with a real output.\")]\n",
    "                state = {**state, \"messages\": messages}\n",
    "            else:\n",
    "                break\n",
    "        return {\"messages\": result}\n",
    "\n",
    "# order status Assistant\n",
    "order_status_prompt_template = prompts.get(\"order_status_template\", \"\")\n",
    "\n",
    "order_status_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            order_status_prompt_template\n",
    "        ),\n",
    "        (\"placeholder\", \"{messages}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "order_status_safe_tools = [structured_rag]\n",
    "order_status_tools = order_status_safe_tools + [ProductValidation]\n",
    "\n",
    "# Return Processing Assistant\n",
    "return_processing_prompt_template = prompts.get(\"return_processing_template\", \"\")\n",
    "\n",
    "return_processing_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            return_processing_prompt_template\n",
    "        ),\n",
    "        (\"placeholder\", \"{messages}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "return_processing_safe_tools = [get_recent_return_details, return_window_validation]\n",
    "return_processing_sensitive_tools = [update_return]\n",
    "return_processing_tools = return_processing_safe_tools + return_processing_sensitive_tools + [ProductValidation]\n",
    "\n",
    "primary_assistant_prompt_template = prompts.get(\"primary_assistant_template\", \"\")\n",
    "\n",
    "primary_assistant_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            primary_assistant_prompt_template\n",
    "        ),\n",
    "        (\"placeholder\", \"{messages}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "primary_assistant_tools = [\n",
    "        HandleOtherTalk,\n",
    "        ToProductQAAssistant,\n",
    "        ToOrderStatusAssistant,\n",
    "        ToReturnProcessing,\n",
    "    ]\n",
    "\n",
    "# BUILD THE GRAPH\n",
    "builder = StateGraph(State)\n",
    "\n",
    "\n",
    "# SUB AGENTS\n",
    "# Create product_qa Assistant\n",
    "builder.add_node(\n",
    "    \"enter_product_qa\",\n",
    "    handle_product_qa,\n",
    ")\n",
    "\n",
    "builder.add_edge(\"enter_product_qa\", END)\n",
    "\n",
    "builder.add_node(\"order_validation\", validate_product_info)\n",
    "builder.add_node(\"ask_clarification\", ask_clarification)\n",
    "\n",
    "# Create order_status Assistant\n",
    "builder.add_node(\n",
    "    \"enter_order_status\", create_entry_node(\"Order Status Assistant\")\n",
    ")\n",
    "builder.add_node(\"order_status\", Assistant(order_status_prompt, order_status_tools))\n",
    "builder.add_edge(\"enter_order_status\", \"order_status\")\n",
    "builder.add_node(\n",
    "    \"order_status_safe_tools\",\n",
    "    create_tool_node_with_fallback(order_status_safe_tools),\n",
    ")\n",
    "\n",
    "\n",
    "def route_order_status(\n",
    "    state: State,\n",
    ") -> Literal[\n",
    "    \"order_status_safe_tools\",\n",
    "    \"order_validation\",\n",
    "    \"__end__\"\n",
    "]:\n",
    "    route = tools_condition(state)\n",
    "    if route == END:\n",
    "        return END\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    tool_names = [t.name for t in order_status_safe_tools]\n",
    "    do_product_validation = any(tc[\"name\"] == ProductValidation.__name__ for tc in tool_calls)\n",
    "    if do_product_validation:\n",
    "        return \"order_validation\"\n",
    "    if all(tc[\"name\"] in tool_names for tc in tool_calls):\n",
    "        return \"order_status_safe_tools\"\n",
    "    return \"order_status_sensitive_tools\"\n",
    "\n",
    "builder.add_edge(\"order_status_safe_tools\", \"order_status\")\n",
    "builder.add_conditional_edges(\"order_status\", route_order_status)\n",
    "\n",
    "# Create return_processing Assistant\n",
    "builder.add_node(\"return_validation\", validate_product_info)\n",
    "\n",
    "builder.add_node(\n",
    "    \"enter_return_processing\",\n",
    "    create_entry_node(\"Return Processing Assistant\"),\n",
    ")\n",
    "builder.add_node(\"return_processing\", Assistant(return_processing_prompt, return_processing_tools))\n",
    "builder.add_edge(\"enter_return_processing\", \"return_processing\")\n",
    "\n",
    "builder.add_node(\n",
    "    \"return_processing_safe_tools\",\n",
    "    create_tool_node_with_fallback(return_processing_safe_tools),\n",
    ")\n",
    "builder.add_node(\n",
    "    \"return_processing_sensitive_tools\",\n",
    "    create_tool_node_with_fallback(return_processing_sensitive_tools),\n",
    ")\n",
    "\n",
    "\n",
    "def route_return_processing(\n",
    "    state: State,\n",
    ") -> Literal[\n",
    "    \"return_processing_safe_tools\",\n",
    "    \"return_processing_sensitive_tools\",\n",
    "    \"return_validation\",\n",
    "    \"__end__\",\n",
    "]:\n",
    "    route = tools_condition(state)\n",
    "    if route == END:\n",
    "        return END\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    do_product_validation = any(tc[\"name\"] == ProductValidation.__name__ for tc in tool_calls)\n",
    "    if do_product_validation:\n",
    "        return \"return_validation\"\n",
    "    tool_names = [t.name for t in return_processing_safe_tools]\n",
    "    if all(tc[\"name\"] in tool_names for tc in tool_calls):\n",
    "        return \"return_processing_safe_tools\"\n",
    "    return \"return_processing_sensitive_tools\"\n",
    "\n",
    "\n",
    "builder.add_edge(\"return_processing_sensitive_tools\", \"return_processing\")\n",
    "builder.add_edge(\"return_processing_safe_tools\", \"return_processing\")\n",
    "builder.add_conditional_edges(\"return_processing\", route_return_processing)\n",
    "\n",
    "\n",
    "def user_info(state: State):\n",
    "    return {\"user_purchase_history\": get_purchase_history(state[\"user_id\"]), \"current_product\": \"\"}\n",
    "\n",
    "builder.add_node(\"fetch_purchase_history\", user_info)\n",
    "builder.add_edge(START, \"fetch_purchase_history\")\n",
    "builder.add_edge(\"ask_clarification\", END)\n",
    "\n",
    "# Primary assistant\n",
    "builder.add_node(\"primary_assistant\", Assistant(primary_assistant_prompt, primary_assistant_tools))\n",
    "builder.add_node(\n",
    "    \"other_talk\", handle_other_talk\n",
    ")\n",
    "\n",
    "#  Add \"primary_assistant_tools\", if necessary\n",
    "def route_primary_assistant(\n",
    "    state: State,\n",
    ") -> Literal[\n",
    "    \"enter_product_qa\",\n",
    "    \"enter_order_status\",\n",
    "    \"enter_return_processing\",\n",
    "    \"other_talk\",\n",
    "    \"__end__\",\n",
    "]:\n",
    "    route = tools_condition(state)\n",
    "    if route == END:\n",
    "        return END\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    if tool_calls:\n",
    "        if tool_calls[0][\"name\"] == ToProductQAAssistant.__name__:\n",
    "            return \"enter_product_qa\"\n",
    "        elif tool_calls[0][\"name\"] == ToOrderStatusAssistant.__name__:\n",
    "            return \"enter_order_status\"\n",
    "        elif tool_calls[0][\"name\"] == ToReturnProcessing.__name__:\n",
    "            return \"enter_return_processing\"\n",
    "        elif tool_calls[0][\"name\"] == HandleOtherTalk.__name__:\n",
    "            return \"other_talk\"\n",
    "    raise ValueError(\"Invalid route\")\n",
    "\n",
    "builder.add_edge(\"other_talk\", END)\n",
    "\n",
    "# The assistant can route to one of the delegated assistants,\n",
    "# directly use a tool, or directly respond to the user\n",
    "builder.add_conditional_edges(\n",
    "    \"primary_assistant\",\n",
    "    route_primary_assistant,\n",
    "    {\n",
    "        \"enter_product_qa\": \"enter_product_qa\",\n",
    "        \"enter_order_status\": \"enter_order_status\",\n",
    "        \"enter_return_processing\": \"enter_return_processing\",\n",
    "        \"other_talk\":\"other_talk\",\n",
    "        END: END,\n",
    "    },\n",
    ")\n",
    "\n",
    "\n",
    "def is_order_product_valid(state: State)  -> Literal[\n",
    "    \"ask_clarification\",\n",
    "    \"order_status\"\n",
    "]:\n",
    "    \"\"\"Conditional edge from validation node to decide if we should ask followup questions\"\"\"\n",
    "    if state[\"needs_clarification\"] == True:\n",
    "        return \"ask_clarification\"\n",
    "    return \"order_status\"\n",
    "\n",
    "def is_return_product_valid(state: State)  -> Literal[\n",
    "    \"ask_clarification\",\n",
    "    \"return_processing\"\n",
    "]:\n",
    "    \"\"\"Conditional edge from validation node to decide if we should ask followup questions\"\"\"\n",
    "    if state[\"needs_clarification\"] == True:\n",
    "        return \"ask_clarification\"\n",
    "    return \"return_processing\"\n",
    "\n",
    "builder.add_conditional_edges(\n",
    "    \"order_validation\",\n",
    "    is_order_product_valid\n",
    ")\n",
    "builder.add_conditional_edges(\n",
    "    \"return_validation\",\n",
    "    is_return_product_valid\n",
    ")\n",
    "\n",
    "builder.add_edge(\"fetch_purchase_history\", \"primary_assistant\")\n",
    "\n",
    "\n",
    "# Allow multiple async loop togeather\n",
    "# This is needed to create checkpoint as it needs async event loop\n",
    "# TODO: Move graph build into a async function and call that to remove nest_asyncio\n",
    "import nest_asyncio\n",
    "nest_asyncio.apply()\n",
    "\n",
    "# To run the async main function\n",
    "import asyncio\n",
    "\n",
    "memory = None\n",
    "pool = None\n",
    "\n",
    "# TODO: Remove pool as it's not getting used\n",
    "# WAR: It's added so postgres does not close it's session\n",
    "async def get_checkpoint():\n",
    "    global memory, pool\n",
    "    memory, pool = await get_checkpointer()\n",
    "\n",
    "asyncio.run(get_checkpoint())\n",
    "\n",
    "# Compile\n",
    "graph = builder.compile(checkpointer=memory,\n",
    "                        interrupt_before=[\"return_processing_sensitive_tools\"],\n",
    "                        #interrupt_after=[\"ask_human\"]\n",
    "                        )\n",
    "\n",
    "try:\n",
    "    # Generate the PNG image from the graph\n",
    "    png_image_data = graph.get_graph(xray=True).draw_mermaid_png()\n",
    "    # Save the image to a file in the current directory\n",
    "    with open(\"graph_image.png\", \"wb\") as f:\n",
    "        f.write(png_image_data)\n",
    "except Exception as e:\n",
    "    # This requires some extra dependencies and is optional\n",
    "    logger.info(f\"An error occurred: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a1920a3",
   "metadata": {},
   "source": [
    "with the guardrails configuration built and wrapped around the agent, we will run the nemoguardrails server. Make sure to add the absolute path of the `config` directory and the `container image` and run the following cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6910e96",
   "metadata": {
    "id": "b618cbbc-4e6c-44db-85d0-b5dbe4617f33"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "docker run -p 8000:8000 -v </path/to/local/config/>:/config <IMAGE_NAME>"
   ]
  },
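  {
   "cell_type": "markdown",
   "id": "b7e41c9a",
   "metadata": {},
   "source": [
    "Once the server is running, you can send it a test message. The following is a minimal sketch, assuming the container runs the standard NeMo Guardrails server and exposes its `/v1/chat/completions` endpoint on `localhost:8000`. The `config_id` value `config` and the helper names `build_chat_request` and `ask_guardrails` are hypothetical; adjust them to match your setup.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import urllib.request\n",
    "\n",
    "# Assumption: the server started above is reachable on localhost:8000\n",
    "GUARDRAILS_URL = \"http://localhost:8000/v1/chat/completions\"\n",
    "\n",
    "\n",
    "def build_chat_request(user_message, config_id=\"config\"):\n",
    "    \"\"\"Build the JSON body for the Guardrails chat endpoint.\"\"\"\n",
    "    return {\n",
    "        # Hypothetical id: use the name of your own guardrails configuration\n",
    "        \"config_id\": config_id,\n",
    "        \"messages\": [{\"role\": \"user\", \"content\": user_message}],\n",
    "    }\n",
    "\n",
    "\n",
    "def ask_guardrails(user_message, config_id=\"config\", url=GUARDRAILS_URL, timeout=60):\n",
    "    \"\"\"POST a single user message to the server and return the parsed JSON reply.\"\"\"\n",
    "    body = json.dumps(build_chat_request(user_message, config_id)).encode(\"utf-8\")\n",
    "    request = urllib.request.Request(\n",
    "        url,\n",
    "        data=body,\n",
    "        headers={\"Content-Type\": \"application/json\"},\n",
    "        method=\"POST\",\n",
    "    )\n",
    "    with urllib.request.urlopen(request, timeout=timeout) as response:\n",
    "        return json.loads(response.read().decode(\"utf-8\"))\n",
    "```\n",
    "\n",
    "For example, `ask_guardrails(\"What is the status of my latest order?\")` returns the server's parsed JSON reply. If the cell above is still running in the foreground, send the request from a separate terminal or notebook."
   ]
  },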
  {
   "cell_type": "markdown",
   "id": "82ee29a0",
   "metadata": {
    "id": "74eeec5a"
   },
   "source": [
    "## Exposing the Interface for Testing (optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34a9104e",
   "metadata": {
    "id": "4f76c683"
   },
   "source": [
    "The Blueprint comes equiped with a basic UI for testing the deployment. This interface is served at port 3001. In order to expose the port and try out the interaction, you need to follow the steps below.\n",
    "\n",
    "First, navigate back to the created Launchable instance page and click on the Access menu.\n",
    "\n",
    "\n",
    "![Access Menu](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-cli-install.png)\n",
    "\n",
    "\n",
    "Scroll down until you find \"Using Tunnels\" section and click on Share a Service button.\n",
    "\n",
    "\n",
    "![Using Tunnels](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-tunnels.png)\n",
    "\n",
    "\n",
    "Enter the port 3001, as that is where the UI service endpoint is. Confirm with Done. Then click on Edit Access and make the port public:\n",
    "\n",
    "\n",
    "![Share Access](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-share-access.png)\n",
    "\n",
    "\n",
    "Past this point, by clicking on the link, the UI should appear in your browser and you are free to interact with the assistant and to ask him about the data that was ingested.\n",
    "\n",
    "\n",
    "![AI Virtual Assistant Interface](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/ai-virtual-assistant-interface.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "262835e1",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
